%% DESCRIPTION

% This script generates stimulus-triggered averages of cortical evoked responses.
% You can analyze a single file or combine two files together to generate the average responses
% (see the COMBINE FILES section).
% Each individual neural sweep is baseline-corrected by subtracting the mean of its first 150 samples
% (a baseline window taken from the PRE_SWEEP) from all the samples in the sweep.
% For one specified channel, the script writes the baseline-corrected amplitude of every individual
% sweep at the specified EP-peak latency (LATENCY) into the output CSV file (for statistical analysis).
% The script also generates plots of the stimulus triggers (Figure 1), the individual neural sweeps of a
% specified channel (Figure 2), the average evoked responses of the left and right channel groups
% (Figures 3-8), and the average evoked response of the conditioning channel (Figure 9).

%% DEFINE VARIABLES
%% DEFINE VARIABLES
baseFilenames = [
    "20171206_002"
];

BASEDIR = '/Users/samiramoorjani/Documents/Yuri/';
DATADIR = [BASEDIR 'Data/'];

% Specify PRE and POST in milliseconds
PRE = 20;
POST = 100;

AMPLIFICATION = 4; % Max. digital value / Max. analog value

% Enter experiment information
EP = 'CCEP';
CURRENT = 120; % Enter stimulation current in microamperes
SITE = 'LM1(18,18)D/S'; % Enter site of stimulation
PHASE = 'Preconditioning'; % Enter experimental phase
LATENCY = 2.7; % Enter EP-peak latency in milliseconds

% Trigger detection and rejection
PEAK_THRESHOLD = 1000;  % Trigger-channel peaks at or below this (after mean removal) are treated as noise
OUTLIER_TRIGGERS = 230; % Indices into the detected trigger list to discard; use [] for none

% Analysis windows, in samples counted from the start of each sweep.
% At 10 kHz, 1:150 spans roughly 20 ms to 5 ms before the stimulus, and 146:195 spans the
% ~5 ms just before the stimulus artifact. These windows assume a 10 kHz sampling rate;
% a warning is issued below if the file's sampling rate differs.
EXPECTED_SAMPLING_FREQUENCY = 10000;
BASELINE_WINDOW = 1 : 150;
PRE_ARTIFACT_WINDOW = 146 : 195;

% Specify location of neural channels (rows of NS4.Data)
neuralChannelLocations = [1 : 20     22      24 : 32     65 : 84     86      88 : 96];

% Note that not-connected channels are 21, 23, 30, 32, 94, and 96
% NOTE(review): if that list refers to electrode numbers, neuralChannelLocations above still
% includes 30, 32, 94 and 96, and the channel IDs below index neuralChannelLocations (so 21 and
% 23 below are electrodes 22 and 25) -- confirm which numbering each list uses.

% Channel IDs below are indices into neuralChannelLocations / channelNames, not row numbers of NS4.Data.
channelId = 22;             % Channel whose individual sweeps are plotted (Figure 2) and written to CSV
conditioningChannelId = 22; % Channel used to study the effect of conditioning (Figure 9)

% Specify left and right channels to be plotted
leftChannelId1 = [  1       2       5       6         9       10 ];
leftChannelId2 = [  13      14      18      21     23  ];
leftChannelId3 = [  24 : 26      27 : 30  ]; % Note 24 is LM1(14,15)-D and 26 is LM1(14,15)-I.
rightChannelId1 = [  35 : 36      39 : 40   43 : 44 ];
rightChannelId2 = [  47 : 48     51 : 54  ];
rightChannelId3 = [  55 : 59  ]; % Removed RSMA(2,20)-I and 31 : 32. Confirm that these are non-working channels.

% Select 'Combine Files' or 'Load Files.' If using 'Combine Files,' specify the two filenames.
% Verify Outliers (OUTLIER_TRIGGERS)
% Verify list of channels to be plotted
% Verify Y-axis limits
% Verify line color for the conditioning mEP plot
% Specify 'w' or 'a' for output file

for fileIndex = 1 : numel(baseFilenames)

    fprintf('Processing %d of %d files\n', fileIndex, numel(baseFilenames));

    baseFilename = char(baseFilenames(fileIndex));
    outputPath = [BASEDIR 'Output/' baseFilename '_test.csv']; % Specify name for the output CSV file
    figureOutputPathPrefix = [BASEDIR 'Figures_BCofSweeps/' baseFilename '_test_']; % This is where the figures will be stored
    ns4Filename = [baseFilename '.ns4'];

    fprintf('%s\n', ns4Filename);

    %% COMBINE FILES

    % Specify the two files that need to be combined together
    % ns4Filename1 = '20180215_001.ns4';
    % ns4Filename2 = '20180215_001.ns5';
    % 
    % % Combine the two files together into a single NSx structure in MATLAB
    % [NS4] = combineNSx([DATADIR ns4Filename1], [DATADIR ns4Filename2]);

    %% LOAD FILES

    % If this .ns4 file is already in the workspace (from an earlier run), don't re-load it
    if exist('filenameExtracted', 'var') ~= 1 || ~strcmp(filenameExtracted, ns4Filename)
        fprintf('Loading NS4 file: %s\n', ns4Filename);
        NS4 = openNSx('report', 'read', [DATADIR ns4Filename], 'p:double');
        if exist('NS4', 'var') ~= 1 || ~isstruct(NS4)
            filenameExtracted = "";
            % Skip this file instead of aborting the remaining files in the batch
            warning('Could not load %s; skipping this file.', ns4Filename);
            continue;
        end
        filenameExtracted = ns4Filename;
    else
        fprintf('Loading already-loaded NS4 file: %s\n', filenameExtracted);
    end

    %% EXTRACT STIMULUS TRIGGERS

    fprintf('Extracting stimulus-trigger locations\n');

    % The last row of NS4.Data holds the stimulus-trigger signal; remove its mean
    triggerData = NS4.Data(end, :);
    triggerData = triggerData - mean(triggerData);

    % Escape every underscore so the TeX interpreter does not subscript it in figure titles
    printNs4Filename = strrep(ns4Filename, '_', '\_');

    % Specify sampling frequency
    samplingFrequency = NS4.MetaTags.SamplingFreq;
    if samplingFrequency ~= EXPECTED_SAMPLING_FREQUENCY
        warning('Sampling frequency is %g Hz, but BASELINE_WINDOW and PRE_ARTIFACT_WINDOW assume %g Hz.', ...
            samplingFrequency, EXPECTED_SAMPLING_FREQUENCY);
    end

    % Convert PRE and POST to samples
    PRE_SWEEP = fix(PRE * samplingFrequency / 1000);
    POST_SWEEP = fix(POST * samplingFrequency / 1000);
    sweepLength = PRE_SWEEP + POST_SWEEP;

    % Candidate triggers: local maxima of the trigger channel above the noise threshold
    [peaks, triggerLocations] = findpeaks(double(triggerData));
    triggerLocations = triggerLocations(peaks > PEAK_THRESHOLD);
    totalTriggerCount = numel(triggerLocations);

    keep = true(size(triggerLocations));

    % Remove triggers that are too close together: a trigger is dropped when the next
    % trigger arrives before its sweep would end
    keep(1 : end - 1) = diff(triggerLocations) >= sweepLength;

    % Remove triggers whose PRE_SWEEP or POST_SWEEP would run past either end of the recording
    sweepFits = (triggerLocations - PRE_SWEEP >= 1) & ...
                (triggerLocations + POST_SWEEP - 1 <= size(NS4.Data, 2));
    keep = keep & sweepFits;

    % Remove outliers. Indices beyond the detected triggers are ignored, because assigning
    % past the end would grow the array with zero-valued "triggers".
    validOutliers = OUTLIER_TRIGGERS(OUTLIER_TRIGGERS >= 1 & OUTLIER_TRIGGERS <= totalTriggerCount);
    if numel(validOutliers) < numel(OUTLIER_TRIGGERS)
        warning('Some OUTLIER_TRIGGERS indices are outside 1..%d and were ignored.', totalTriggerCount);
    end
    keep(validOutliers) = false;

    % Rejected triggers are marked NaN; only the kept triggers are analyzed
    triggerLocations(~keep) = NaN;
    correctedTriggerLocations = triggerLocations(keep);
    triggerCount = numel(correctedTriggerLocations);

    % Plot triggerData
    figure
    plot(triggerData);
    title(sprintf('Graph of file ''%s''; Total number of triggers: %d, Corrected number of triggers: %d', printNs4Filename, totalTriggerCount, triggerCount))
    grid on
    savefig([figureOutputPathPrefix '1.fig'])

    if triggerCount == 0
        warning('No usable triggers found in %s; skipping this file.', ns4Filename);
        continue;
    end

    %% GENERATE NEURAL SWEEPS

    channelNames = {NS4.ElectrodesInfo(neuralChannelLocations).Label};
    neuralChannelCount = numel(neuralChannelLocations);

    % Sweep sample at which the EP peak is read. The trigger itself is sweep sample
    % PRE_SWEEP + 1, while timeAxis (below) labels sample PRE_SWEEP as t = 0; this index
    % follows the timeAxis convention, so it reads the value the plots show at LATENCY.
    PEAK_LATENCY = PRE_SWEEP + fix(LATENCY * samplingFrequency / 1000);
    assert(PEAK_LATENCY >= 1 && PEAK_LATENCY <= sweepLength, ...
        'LATENCY (%g ms) falls outside the PRE/POST sweep window.', LATENCY);

    % Element (k, t) is the recording sample for sweep sample k of kept trigger t
    % (column + row uses implicit expansion to form a sweepLength x triggerCount matrix)
    sweepIndices = (0 : sweepLength - 1)' + (correctedTriggerLocations - PRE_SWEEP);

    neuralSweeps = nan(sweepLength, triggerCount, neuralChannelCount);
    evokedPotentialPeaks = nan(1, triggerCount, neuralChannelCount);

    fprintf('Processing neural data\n');

    for ch = 1 : neuralChannelCount
        neuralChannelData = NS4.Data(neuralChannelLocations(ch), :);
        sweeps = neuralChannelData(sweepIndices); % sweepLength x triggerCount

        % Baseline correction: subtract each sweep's mean over BASELINE_WINDOW from the whole sweep
        sweeps = sweeps - mean(sweeps(BASELINE_WINDOW, :), 1);

        neuralSweeps(:, :, ch) = sweeps;
        evokedPotentialPeaks(1, :, ch) = sweeps(PEAK_LATENCY, :);
    end

    % Plot neuralSweeps for the specified channel
    figure
    plot(neuralSweeps(:, :, channelId))
    ylim([-6000  6000])
    title(sprintf('Graph of file ''%s'', Site ''%s''', printNs4Filename, channelNames{channelId}))
    grid on
    savefig([figureOutputPathPrefix '2.fig'])

    %% OUTPUT EP-PEAK VALUES OF INDIVIDUAL SWEEPS
    channelEvokedPotentialPeaks = evokedPotentialPeaks(:, :, channelId) / AMPLIFICATION; % Conversion of raw values to microvolts
    meanPeakValue = mean(channelEvokedPotentialPeaks);

    % Write EP-peak values into a CSV file, one value per line
    outputFile = fopen(outputPath, 'w'); % Opens output file in write mode
    % outputFile = fopen(outputPath, 'a'); % Opens output file in write mode. Appends data to the end of the file.
    if outputFile == -1
        error('Could not open output file for writing: %s', outputPath);
    end
    fprintf(outputFile, '%0.4f\n', channelEvokedPotentialPeaks);
    fclose(outputFile);

    %% PLOT EVOKED RESPONSES

    % Mean across triggers; squeeze leaves one column per channel (sweepLength x neuralChannelCount)
    mEP = squeeze(mean(neuralSweeps, 2));
    mEPMicrovolts = mEP / AMPLIFICATION; % Conversion of raw values to microvolts

    fprintf('Generating plots\n');

    % X-axis in milliseconds, built from integer sample indices so its length always equals
    % size(mEP, 1). t = 0 falls at sweep sample PRE * samplingFrequency / 1000 (= PRE_SWEEP when
    % that is an integer), which is one sample before the trigger sample PRE_SWEEP + 1.
    timeAxis = (1 : size(mEP, 1)) * 1000 / samplingFrequency - PRE;

    % Per-channel statistics of the averaged responses (microvolts)
    preSweepMeans = mean(mEPMicrovolts(PRE_ARTIFACT_WINDOW, :), 1);   % Mean just before the stimulus artifact
    baselineMeans = mean(mEPMicrovolts(BASELINE_WINDOW, :), 1);       % Mean of the baseline window
    baselineStds = std(mEPMicrovolts(BASELINE_WINDOW, :), 0, 1);      % Standard deviation of the baseline window
    twoStdAboveBaselines = baselineMeans + 2 * baselineStds;

    % Per-group statistics, kept as workspace variables for inspection after the script runs
    meanOfLeftPreSweep1 = preSweepMeans(leftChannelId1);
    meanOfLeftPreSweep2 = preSweepMeans(leftChannelId2);
    meanOfLeftPreSweep3 = preSweepMeans(leftChannelId3);

    meanOfLeftBaseline1 = baselineMeans(leftChannelId1);
    stdOfLeftBaseline1 = baselineStds(leftChannelId1);
    twoStdAboveLeftBaseline1 = twoStdAboveBaselines(leftChannelId1);

    meanOfLeftBaseline2 = baselineMeans(leftChannelId2);
    stdOfLeftBaseline2 = baselineStds(leftChannelId2);
    twoStdAboveLeftBaseline2 = twoStdAboveBaselines(leftChannelId2);

    meanOfLeftBaseline3 = baselineMeans(leftChannelId3);
    stdOfLeftBaseline3 = baselineStds(leftChannelId3);
    twoStdAboveLeftBaseline3 = twoStdAboveBaselines(leftChannelId3);

    meanOfRightPreSweep1 = preSweepMeans(rightChannelId1);
    meanOfRightPreSweep2 = preSweepMeans(rightChannelId2);
    meanOfRightPreSweep3 = preSweepMeans(rightChannelId3);

    meanOfRightBaseline1 = baselineMeans(rightChannelId1);
    stdOfRightBaseline1 = baselineStds(rightChannelId1);
    twoStdAboveRightBaseline1 = twoStdAboveBaselines(rightChannelId1);

    meanOfRightBaseline2 = baselineMeans(rightChannelId2);
    stdOfRightBaseline2 = baselineStds(rightChannelId2);
    twoStdAboveRightBaseline2 = twoStdAboveBaselines(rightChannelId2);

    meanOfRightBaseline3 = baselineMeans(rightChannelId3);
    stdOfRightBaseline3 = baselineStds(rightChannelId3);
    twoStdAboveRightBaseline3 = twoStdAboveBaselines(rightChannelId3);

    % Shared title for the channel-group figures (\\mu in the format becomes the TeX micro sign)
    evokedResponseTitle = sprintf('Graph of file ''%s''; Recorded ''%s'' at %0.1f \\muA delivered to ''%s''; Number of triggers: %d', ...
        printNs4Filename, EP, CURRENT, SITE, triggerCount);

    % Figures 3-5: left channel groups; Figures 6-8: right channel groups
    channelGroups = {leftChannelId1, leftChannelId2, leftChannelId3, ...
                     rightChannelId1, rightChannelId2, rightChannelId3};

    for g = 1 : numel(channelGroups)
        figure
        hold on
        for c = channelGroups{g}
            plot(timeAxis, mEPMicrovolts(:, c), 'LineWidth', 1.5, 'DisplayName', sprintf('''%s''', channelNames{c}))
        end
        ylim([-50 50])
        xlabel({'Time (ms)'}, 'FontSize', 12, 'FontWeight', 'bold')
        ylabel({'Amplitude (\muV)'}, 'FontSize', 12, 'FontWeight', 'bold')
        title(evokedResponseTitle)
        legend('show')
        hold off
        savefig([figureOutputPathPrefix num2str(g + 2) '.fig'])
    end

    % Figure 9: conditioning channel
    singleChannelMEP = mEP(:, conditioningChannelId);
    meanOfPreSweep = preSweepMeans(conditioningChannelId);
    meanOfBaseline = baselineMeans(conditioningChannelId);
    stdOfBaseline = baselineStds(conditioningChannelId);
    twoStdAboveBaseline = twoStdAboveBaselines(conditioningChannelId);

    figure
    plot(timeAxis, singleChannelMEP / AMPLIFICATION, 'Color', [0, 0, 1], 'LineWidth', 1.5, 'DisplayName', sprintf('''%s''', PHASE))
    ylim([-50 50])
    xlabel({'Time (ms)'}, 'FontSize', 12, 'FontWeight', 'bold')
    ylabel({'Amplitude (\muV)'}, 'FontSize', 12, 'FontWeight', 'bold')
    title(sprintf('Recorded ''%s'' at %0.1f \\muA delivered to ''%s''; Recording site = ''%s''', EP, CURRENT, SITE, channelNames{conditioningChannelId}))
    legend('show')

    savefig([figureOutputPathPrefix '9.fig'])

end
